from transformers import LlamaTokenizer, LlamaForCausalLM
from transformers.models.llama.modeling_llama import (
    LlamaPreTrainedModel,
    LlamaConfig,
    LlamaDecoderLayer,
    LlamaRMSNorm,
)
from transformers.utils.fx import symbolic_trace

# --- Tokenizer demo -------------------------------------------------------
# Fetch the Atom-7B tokenizer from the Hugging Face hub, encode one sample
# sentence into PyTorch tensors, and inspect the resulting token-id batch.
tokenizer = LlamaTokenizer.from_pretrained('FlagAlpha/Atom-7B')

text = 'how are you?'
encoded_input = tokenizer(text, return_tensors='pt')
input_ids = encoded_input['input_ids']

print(encoded_input)
print(input_ids.shape)

# --- FX tracing demo ------------------------------------------------------
# Build a small, randomly initialised LLaMA causal LM and symbolically
# trace it with the transformers torch.fx helper (imported at file top).
# Passing num_hidden_layers to the constructor (instead of mutating the
# config after creation) ensures the setting is in place before any
# config-derived state is computed.
cfg = LlamaConfig(num_hidden_layers=4)  # 4 decoder layers keeps the demo small
model = LlamaForCausalLM(cfg)

# symbolic_trace infers input names from the model's dummy inputs by
# default; the traced GraphModule's printout shows the captured graph.
traced = symbolic_trace(model)
print(traced)
